import os
import random
from shutil import copy


def mkdir(file):
    """Create the directory ``file`` (and any missing parents) if absent.

    ``exist_ok=True`` replaces a separate ``os.path.exists`` check. The old
    check-then-create sequence could crash with ``FileExistsError`` if the
    directory appeared between the two calls. It also silently accepted a
    path that exists as a regular *file*. Later ``shutil.copy`` calls
    targeting that path would then overwrite the file instead of copying
    into a directory.

    Args:
        file: Path of the directory to create. The name is kept for
            backward compatibility; it is a directory, not a file.

    Raises:
        FileExistsError: if ``file`` exists but is not a directory.
    """
    os.makedirs(file, exist_ok=True)


# Source dataset root: every sub-directory of it is treated as one class.
# Plain files sitting next to the class folders (e.g. .DS_Store, Thumbs.db)
# are skipped, otherwise os.listdir() on them below would raise
# NotADirectoryError. Sorted so the class order does not depend on the file
# system's listing order.
file_path = 'data_cat_dog'
data_class = [cla for cla in sorted(os.listdir(file_path))
              if os.path.isdir(os.path.join(file_path, cla))]

# Output roots for the two splits; each gets one sub-directory per class.
train_dir = './data/train'
test_dir = './data/test'

# Training set
mkdir(train_dir)
for cla in data_class:
    mkdir(os.path.join(train_dir, cla))

# Held-out (test) set
mkdir(test_dir)
for cla in data_class:
    mkdir(os.path.join(test_dir, cla))

# Fraction of each class's images that is copied into the test split; the
# rest go to the training split.
split_rate = 0.1

# Seed for the split. None draws fresh entropy from the OS, so each run gives
# a different split. Set an int to make the split reproducible. A private
# Random instance avoids touching the global `random` state.
seed = None
rng = random.Random(seed)

# For every class, sample the test images and copy each image into the
# matching class folder of its split.
for cla in data_class:
    cla_path = os.path.join(file_path, cla)
    # Only regular files are copied (shutil.copy fails on directories).
    # Sorted so that, with a fixed seed, the split is stable across machines.
    images = sorted(image for image in os.listdir(cla_path)
                    if os.path.isfile(os.path.join(cla_path, image)))
    num = len(images)
    # A set gives O(1) membership tests below; checking against the list
    # returned by sample() made this loop O(n^2) in the class size.
    eval_index = set(rng.sample(images, k=int(num * split_rate)))
    for index, image in enumerate(images):
        dest_root = test_dir if image in eval_index else train_dir
        copy(os.path.join(cla_path, image), os.path.join(dest_root, cla))
        print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")
    print()
print("processing done!")